# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test sharded loader"""
# pylint: disable=missing-docstring
import json
import tempfile

import numpy as np

import tvm
import tvm.testing
from tvm import dlight as dl
from tvm import relax as rx
from tvm_ffi import register_global_func
from tvm.contrib import tvmjs
from tvm.runtime import ShapeTuple
from tvm.runtime import disco as di
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.target import Target
from tvm.contrib import tvmjs


@register_global_func("tests.disco.shard_dim_0", override=True)
def _shard_dim_0(src, num_shards, tgt):
    """Split the 2-D ``src`` along axis 0 and copy the stacked shards into ``tgt``.

    ``src`` of shape ``(s_0, s_1)`` is reshaped to
    ``(num_shards, s_0 // num_shards, s_1)``, so shard ``i`` holds the ``i``-th
    contiguous group of rows.
    """
    rows, cols = src.shape
    rows_per_shard = rows // num_shards
    stacked = src.numpy().reshape(num_shards, rows_per_shard, cols)
    tgt.copyfrom(stacked)


@register_global_func("tests.disco.shard_dim_1", override=True)
def _shard_dim_1(src, num_shards, tgt):
    """Split the 2-D ``src`` along axis 1 and copy the stacked shards into ``tgt``.

    ``src`` of shape ``(s_0, s_1)`` is reshaped to
    ``(s_0, num_shards, s_1 // num_shards)`` and the first two axes are then
    swapped, giving ``(num_shards, s_0, s_1 // num_shards)``.
    """
    rows, cols = src.shape
    cols_per_shard = cols // num_shards
    column_groups = src.numpy().reshape(rows, num_shards, cols_per_shard)
    # Bring the shard axis to the front so that index ``i`` selects shard ``i``.
    tgt.copyfrom(np.transpose(column_groups, (1, 0, 2)))


@register_global_func("tests.disco.shard_qkv_0", override=True)
def _shard_qkv_0(src, num_shards, q_heads, kv_heads, tgt):
    """Shard a fused QKV weight per head group and copy the result into ``tgt``.

    ``src`` is a 2-D array of shape ``(total_dim, hidden_size)`` whose rows are
    laid out as the Q block, then the K block, then the V block.  Each block is
    reshaped to ``(num_shards, heads // num_shards, head_dim, hidden_size)`` and
    the three results are concatenated along axis 1, so every shard holds its
    own Q, K and V heads next to each other.

    Parameters
    ----------
    src
        Source tensor; must expose ``shape`` and ``numpy()``.
    num_shards
        Number of shards to split the heads into.
    q_heads
        Number of query heads in ``src``.
    kv_heads
        Number of key heads (and, equally, value heads) in ``src``.
    tgt
        Destination tensor; receives the result via ``copyfrom``.
    """
    total_dim, hidden_size = src.shape
    head_dim = total_dim // (q_heads + kv_heads + kv_heads)
    q_dim = q_heads * head_dim
    kv_dim = kv_heads * head_dim

    # Materialize the NumPy array once and slice it, instead of converting the
    # same tensor again for each of the Q, K and V blocks.
    weights = src.numpy()

    def _split_heads(block, heads):
        # View a (heads * head_dim, hidden_size) block as per-shard head groups.
        return block.reshape(num_shards, heads // num_shards, head_dim, hidden_size)

    w_q = _split_heads(weights[:q_dim, :], q_heads)
    w_k = _split_heads(weights[q_dim : q_dim + kv_dim, :], kv_heads)
    w_v = _split_heads(weights[q_dim + kv_dim :, :], kv_heads)
    tgt.copyfrom(np.concatenate([w_q, w_k, w_v], axis=1))


@register_global_func("tests.disco.shard_qkv_1", override=True)
def _shard_qkv_1(src, tgt):
    """Merge the two middle axes of the 4-D ``src`` and copy the result into ``tgt``.

    The leading and trailing axes of ``src`` are kept; the two axes between
    them are collapsed into a single one.
    """
    leading, _, _, trailing = src.shape
    merged = src.numpy().reshape(leading, -1, trailing)
    tgt.copyfrom(merged)


def _create_loader(sess, path, param_dict, shard_info):
    """Dump ``param_dict`` under ``path`` and create a ShardLoader in ``sess``.

    The parameters are written as a raw-encoded tensor cache; the generated
    ``tensor-cache.json`` is then read back and handed to the
    ``runtime.disco.ShardLoader`` global function together with the
    JSON-serialized ``shard_info``.  Returns whatever that function returns.
    """
    cache_json_path = path + "/tensor-cache.json"
    tvmjs.dump_tensor_cache(param_dict, path, encode_format="raw")
    with open(cache_json_path, "r", encoding="utf-8") as cache_file:
        cache_metadata = cache_file.read()
    make_loader = sess.get_global_func("runtime.disco.ShardLoader")
    return make_loader(cache_json_path, cache_metadata, json.dumps(shard_info), None)


def _simulate_presharded_weights(base_path, param_dict, num_shards, shard_info):
    """Create fake weights to simulate those produced by MLC-LLM's pre-sharding.

    Every array in ``param_dict`` is split into ``num_shards`` pieces along the
    axis given by ``shard_info[name]``.  The pieces are written to ``base_path``
    as a raw-encoded tensor cache, with entries named
    ``{name}_shard-{i+1}-of-{num_shards}``.
    """
    shards_by_param = {}
    for key, ndarray in param_dict.items():
        assert key in shard_info, f"ShardInfo lacks shard info about param: {key}"
        pieces = np.split(ndarray, num_shards, axis=shard_info[key])
        shards_by_param[key] = [tvm.runtime.tensor(piece) for piece in pieces]

    # Emit entries shard-major (all parameters of shard 1, then all of
    # shard 2, ...).  This matches the ordering used by MLC-LLM, and avoids
    # having *.bin files that must be accessed by more than one worker.
    ordered_entries = {}
    for i in range(num_shards):
        for key, shards in shards_by_param.items():
            ordered_entries[f"{key}_shard-{i+1}-of-{num_shards}"] = shards[i]

    tvmjs.dump_tensor_cache(ordered_entries, base_path, encode_format="raw")


def test_load_shard():
    """Load two parameters one by one with ShardLoaderLoad and check every shard.

    ``x_0`` is sharded by columns and ``x_1`` by rows across two workers; each
    worker's tensor must equal the matching slice of the original array.
    """
    devices = [0, 1]
    num_shards = len(devices)
    param_dict = {
        "x_0": np.random.uniform(size=[64, 128]).astype("float16"),
        "x_1": np.random.uniform(size=[32, 128]).astype("float32"),
    }
    # Each rule is [shard function name, [output shape, dtype], extra args...].
    x_0_rule = ["tests.disco.shard_dim_1", [(num_shards, 64, 64), "float16"], num_shards]
    x_1_rule = ["tests.disco.shard_dim_0", [(num_shards, 16, 128), "float32"], num_shards]
    shard_info = {"x_0": [x_0_rule], "x_1": [x_1_rule]}

    with tempfile.TemporaryDirectory() as path:
        sess = di.ThreadedSession(num_workers=num_shards)
        sess.init_ccl("nccl", *devices)
        loader = _create_loader(sess, path, param_dict, shard_info)
        load_param = sess.get_global_func("runtime.disco.ShardLoaderLoad")
        d_0 = load_param(loader, ShapeTuple([0]))
        d_1 = load_param(loader, ShapeTuple([1]))

        x_0, x_1 = param_dict["x_0"], param_dict["x_1"]
        # (expected slice, remote handle, worker id)
        checks = [
            (x_0[:, 0:64], d_0, 0),
            (x_0[:, 64:128], d_0, 1),
            (x_1[0:16, :], d_1, 0),
            (x_1[16:32, :], d_1, 1),
        ]
        for expected, remote, worker_id in checks:
            np.testing.assert_equal(
                expected,
                remote.debug_get_from_remote(worker_id).numpy(),
            )


def _create_presharded_loader(sess, path):
    """Create a ShardLoader in ``sess`` over a tensor cache already present in ``path``.

    Reads ``tensor-cache.json`` from ``path`` and calls the
    ``runtime.disco.ShardLoader`` global function with an empty shard-info
    JSON object.  Returns whatever that function returns.
    """
    cache_json_path = path + "/tensor-cache.json"
    with open(cache_json_path, "r", encoding="utf-8") as cache_file:
        cache_metadata = cache_file.read()
    make_loader = sess.get_global_func("runtime.disco.ShardLoader")
    return make_loader(cache_json_path, cache_metadata, json.dumps({}), None)


def test_load_presharded():
    """Load pre-sharded parameters one by one and check each worker's shard.

    The weights are split on disk beforehand (``x_0`` along axis 1, ``x_1``
    along axis 0) and then loaded with ``ShardLoaderLoadPresharded``.
    """
    devices = [0, 1]
    worker_count = len(devices)
    param_dict = {
        "x_0": np.random.uniform(size=[64, 128]).astype("float16"),
        "x_1": np.random.uniform(size=[32, 128]).astype("float32"),
    }
    # Parameter name -> axis along which it is split.
    shard_info = {"x_0": 1, "x_1": 0}

    with tempfile.TemporaryDirectory() as path:
        _simulate_presharded_weights(path, param_dict, worker_count, shard_info)
        sess = di.ThreadedSession(num_workers=worker_count)
        sess.init_ccl("nccl", *devices)

        loader = _create_presharded_loader(sess, path)
        load_param = sess.get_global_func("runtime.disco.ShardLoaderLoadPresharded")

        d_0 = load_param(loader, ShapeTuple([0]))
        d_1 = load_param(loader, ShapeTuple([1]))

        x_0, x_1 = param_dict["x_0"], param_dict["x_1"]
        # (expected slice, remote handle, worker id)
        checks = [
            (x_0[:, 0:64], d_0, 0),
            (x_0[:, 64:128], d_0, 1),
            (x_1[0:16, :], d_1, 0),
            (x_1[16:32, :], d_1, 1),
        ]
        for expected, remote, worker_id in checks:
            np.testing.assert_equal(
                expected,
                remote.debug_get_from_remote(worker_id).numpy(),
            )


def test_load_shard_in_relax():
    """Call ShardLoaderLoad from a compiled Relax function and check every shard.

    A Relax module whose ``main`` loads two parameters through
    ``runtime.disco.ShardLoaderLoad`` is compiled, exported to a shared
    library, loaded into a two-worker session and executed.  The tuple
    returned on each worker is compared against the matching slices of the
    original arrays.
    """
    devices = [0, 1]
    num_shards = len(devices)
    param_dict = {
        "x_0": np.random.uniform(size=[64, 128]).astype("float16"),
        "x_1": np.random.uniform(size=[32, 128]).astype("float32"),
    }
    # Each rule is [shard function name, [output shape, dtype], extra args...].
    shard_info = {
        "x_0": [
            [
                "tests.disco.shard_dim_1",
                [(num_shards, 64, 64), "float16"],
                num_shards,
            ],
        ],
        "x_1": [
            [
                "tests.disco.shard_dim_0",
                [(num_shards, 16, 128), "float32"],
                num_shards,
            ]
        ],
    }

    # NOTE(review): "x_0" is created and sharded as float16 above, but the
    # module below annotates the first loaded tensor as float32 -- confirm
    # whether the annotation is intentionally loose.
    # pylint: disable=invalid-name
    @I.ir_module
    class Module:  # pylint: disable=too-few-public-methods
        @R.function
        def main(
            loader: R.Object,
        ) -> R.Tuple(R.Tensor((64, 64), "float32"), R.Tensor((16, 128), "float32")):
            R.func_attr({"global_symbol": "main"})
            with R.dataflow():
                lv0: R.Tensor((64, 64), "float32") = R.call_pure_packed(
                    "runtime.disco.ShardLoaderLoad",
                    loader,
                    R.shape([0]),
                    sinfo_args=R.Tensor((64, 64), "float32"),
                )
                lv1: R.Tensor((16, 128), "float32") = R.call_pure_packed(
                    "runtime.disco.ShardLoaderLoad",
                    loader,
                    R.shape([1]),
                    sinfo_args=R.Tensor((16, 128), "float32"),
                )
                lv2 = R.tuple(lv0, lv1)
                R.output(lv2)
            return lv2

    # pylint: enable=invalid-name
    def relax_build(mod, target):
        """Run the "zero" Relax pipeline on ``mod`` and compile the result.

        NOTE(review): ``target`` is only used as the enclosing context; the
        compile call itself passes the literal ``"cuda"`` -- confirm intended.
        """
        with target:
            mod = rx.get_pipeline("zero")(mod)  # pylint: disable=no-value-for-parameter
            return tvm.compile(mod, target="cuda")

    target = Target(
        {
            "kind": "cuda",
            "max_shared_memory_per_block": 49152,
            "max_threads_per_block": 1024,
            "thread_warp_size": 32,
            "registers_per_block": 65536,
            "arch": "sm_80",
        }
    )
    with tempfile.TemporaryDirectory() as tmpdir:
        # The same temporary directory holds both the compiled library and
        # the tensor cache written by _create_loader.
        dso_path = tmpdir + "/test.so"
        sess = di.ThreadedSession(num_workers=len(devices))
        sess.init_ccl("nccl", *devices)
        relax_build(Module, target).export_library(dso_path)

        mod = sess.load_vm_module(dso_path)
        loader = _create_loader(sess, tmpdir, param_dict, shard_info)
        result = mod["main"](loader)
        # result is fetched per worker; element [0] is compared with x_0's
        # column halves and element [1] with x_1's row halves.
        np.testing.assert_equal(
            param_dict["x_0"][:, 0:64],
            result.debug_get_from_remote(0)[0].numpy(),
        )
        np.testing.assert_equal(
            param_dict["x_0"][:, 64:128],
            result.debug_get_from_remote(1)[0].numpy(),
        )
        np.testing.assert_equal(
            param_dict["x_1"][0:16, :],
            result.debug_get_from_remote(0)[1].numpy(),
        )
        np.testing.assert_equal(
            param_dict["x_1"][16:32, :],
            result.debug_get_from_remote(1)[1].numpy(),
        )


def test_load_shard_all():
    """Load all parameters at once with ShardLoaderLoadAll and check every shard.

    ``param_0`` is sharded by columns (``tests.disco.shard_dim_1``) and
    ``param_1`` by rows (``tests.disco.shard_dim_0``) across two workers.  The
    parameter list fetched from each worker must equal the matching slices of
    the original arrays.
    """
    devices = [0, 1]
    num_shards = len(devices)
    param_dict = {
        "param_0": np.random.uniform(size=[64, 128]).astype("float16"),
        "param_1": np.random.uniform(size=[32, 128]).astype("float32"),
    }
    # Each rule is [shard function name, [output shape, dtype], extra args...].
    shard_info = {
        "param_0": [
            [
                "tests.disco.shard_dim_1",
                [(num_shards, 64, 64), "float16"],
                num_shards,
            ],
        ],
        "param_1": [
            [
                "tests.disco.shard_dim_0",
                # The leading dimension is the shard count; derive it from
                # num_shards like the rule above rather than hard-coding it.
                [(num_shards, 16, 128), "float32"],
                num_shards,
            ]
        ],
    }
    with tempfile.TemporaryDirectory() as path:
        sess = di.ThreadedSession(num_workers=num_shards)
        sess.init_ccl("nccl", *devices)
        loader = _create_loader(sess, path, param_dict, shard_info)
        loader_load = sess.get_global_func("runtime.disco.ShardLoaderLoadAll")
        params = loader_load(loader)
        # Fetch the full parameter list held by each worker.
        p_0 = params.debug_get_from_remote(0)
        p_1 = params.debug_get_from_remote(1)
        np.testing.assert_equal(param_dict["param_0"][:, 0:64], p_0[0].numpy())
        np.testing.assert_equal(param_dict["param_0"][:, 64:128], p_1[0].numpy())
        np.testing.assert_equal(param_dict["param_1"][0:16, :], p_0[1].numpy())
        np.testing.assert_equal(param_dict["param_1"][16:32, :], p_1[1].numpy())


def test_load_all_presharded():
    """Load all pre-sharded parameters at once and check every worker's shard.

    The weights are split on disk beforehand (``param_0`` along axis 0,
    ``param_1`` along axis 1) and then loaded with
    ``ShardLoaderLoadAllPresharded``.  The parameter list fetched from each
    worker must equal the matching slices of the original arrays.
    """
    devices = [0, 1]
    num_shards = len(devices)
    param_dict = {
        "param_0": np.random.uniform(size=[64, 128]).astype("float16"),
        "param_1": np.random.uniform(size=[32, 128]).astype("float32"),
    }
    # Parameter name -> axis along which it is split.
    shard_info = {
        "param_0": 0,
        "param_1": 1,
    }
    with tempfile.TemporaryDirectory() as path:
        # Use the single num_shards value for both the on-disk split and the
        # session size so the two cannot drift apart.
        _simulate_presharded_weights(path, param_dict, num_shards, shard_info)

        sess = di.ThreadedSession(num_workers=num_shards)
        sess.init_ccl("nccl", *devices)
        loader = _create_presharded_loader(sess, path)
        loader_load = sess.get_global_func("runtime.disco.ShardLoaderLoadAllPresharded")
        params = loader_load(loader)

        # Fetch the full parameter list held by each worker.
        p_0 = params.debug_get_from_remote(0)
        p_1 = params.debug_get_from_remote(1)

        np.testing.assert_equal(param_dict["param_0"][0:32, :], p_0[0].numpy())
        np.testing.assert_equal(param_dict["param_0"][32:64, :], p_1[0].numpy())
        np.testing.assert_equal(param_dict["param_1"][:, 0:64], p_0[1].numpy())
        np.testing.assert_equal(param_dict["param_1"][:, 64:128], p_1[1].numpy())


def test_load_shard_broadcast():
    """Load parameters with an empty shard_info and expect full copies everywhere.

    With no sharding rules, the parameter list fetched from each of the two
    workers must equal the original, unsplit arrays.
    """
    devices = [0, 1]
    param_dict = {
        "param_0": np.random.uniform(size=[64, 128]).astype("float16"),
        "param_1": np.random.uniform(size=[32, 128]).astype("float32"),
    }
    shard_info = {}
    with tempfile.TemporaryDirectory() as path:
        sess = di.ThreadedSession(num_workers=len(devices))
        sess.init_ccl("nccl", *devices)
        loader = _create_loader(sess, path, param_dict, shard_info)
        load_all = sess.get_global_func("runtime.disco.ShardLoaderLoadAll")
        params = load_all(loader)
        # Fetch every worker's parameter list up front, worker 0 first.
        per_worker = [params.debug_get_from_remote(worker_id) for worker_id in (0, 1)]
        for index, name in enumerate(("param_0", "param_1")):
            for worker_params in per_worker:
                np.testing.assert_equal(param_dict[name], worker_params[index].numpy())


def test_load_qkv_proj_shard():  # pylint: disable=too-many-locals
    """Shard a fused QKV weight with a two-step rule and check both workers.

    The fused weight is first regrouped per shard by ``tests.disco.shard_qkv_0``
    and then flattened by ``tests.disco.shard_qkv_1``.  The expected per-shard
    result is computed independently with NumPy.
    """
    devices = [0, 1]
    num_shards = len(devices)
    q_heads = 8
    kv_heads = 10
    head_dim = 10
    hidden_size = 20

    def _random_projection(heads):
        # Random float16 weight of shape (heads * head_dim, hidden_size).
        return np.random.uniform(size=[heads * head_dim, hidden_size]).astype("float16")

    w_q = _random_projection(q_heads)
    w_k = _random_projection(kv_heads)
    w_v = _random_projection(kv_heads)
    param_dict = {"w_qkv": np.concatenate([w_q, w_k, w_v], axis=0)}

    # Reference result: group each projection's heads by shard, place Q/K/V
    # side by side within a shard, then flatten heads and head_dim together.
    grouped = [
        weight.reshape((num_shards, heads // num_shards, head_dim, hidden_size))
        for weight, heads in ((w_q, q_heads), (w_k, kv_heads), (w_v, kv_heads))
    ]
    np_qkv = np.concatenate(grouped, axis=1).reshape((num_shards, -1, hidden_size))

    heads_per_shard = (q_heads + kv_heads * 2) // num_shards
    regroup_rule = [
        "tests.disco.shard_qkv_0",
        [(num_shards, heads_per_shard, head_dim, hidden_size), "float16"],
        num_shards,
        q_heads,
        kv_heads,
    ]
    flatten_rule = [
        "tests.disco.shard_qkv_1",
        [(num_shards, heads_per_shard * head_dim, hidden_size), "float16"],
    ]
    shard_info = {"w_qkv": [regroup_rule, flatten_rule]}

    with tempfile.TemporaryDirectory() as path:
        sess = di.ThreadedSession(num_workers=num_shards)
        sess.init_ccl("nccl", *devices)
        loader = _create_loader(sess, path, param_dict, shard_info)
        load_param = sess.get_global_func("runtime.disco.ShardLoaderLoad")
        d_0 = load_param(loader, ShapeTuple([0]))
        for worker_id in range(num_shards):
            np.testing.assert_equal(
                np_qkv[worker_id],
                d_0.debug_get_from_remote(worker_id).numpy(),
            )


# When this file is executed directly as a script, hand control to
# tvm.testing.main().
if __name__ == "__main__":
    tvm.testing.main()
